"""Make figure showing association between gender and genre."""
import argparse
import os

import matplotlib.pyplot as plt
import matplotlib.style
import numpy as np
import pandas as pd
import tqdm
from sklearn import feature_extraction, feature_selection, linear_model, model_selection

import datasets_gender_genre

parser = argparse.ArgumentParser()
parser.add_argument("output_filename", help="Output path for figure.")


# Use the seaborn "deep" style globally. Matplotlib 3.6 renamed its bundled
# seaborn styles to "seaborn-v0_8-<name>" and later releases dropped the old
# names, so prefer the new name and fall back to the old one on older versions.
_SEABORN_DEEP_STYLE = (
    "seaborn-v0_8-deep"
    if "seaborn-v0_8-deep" in matplotlib.style.available
    else "seaborn-deep"
)
matplotlib.style.use(_SEABORN_DEEP_STYLE)


def leave_one_out_logistic_loss(X, y, index):
    """Calculate leave-one-out logistic loss for observation at `index`.

    A logistic regression is fit on every row of `X` / `y` except row
    `index`; the held-out row is then scored with that model.

    Return the negative log probability which the model assigns to the
    *observed* label ``y[index]`` (lower is better; a value below
    ``-log(0.5)`` means the observed label got more than half of the
    probability mass).

    Raise ``ValueError`` if the held-out label does not occur among the
    remaining training labels, since the model then has no probability
    column for it.

    """
    X, y = np.asarray(X), np.asarray(y)
    clf = linear_model.LogisticRegression(C=1, fit_intercept=True, random_state=1)
    clf.fit(np.delete(X, index, axis=0), np.delete(y, index))

    # Columns of `predict_log_proba` follow `clf.classes_`, so look the label
    # up there instead of assuming that the label value is the column number.
    held_out_label = y[index]
    label_columns = np.flatnonzero(clf.classes_ == held_out_label)
    if label_columns.size == 0:
        raise ValueError(
            f"label {held_out_label!r} of observation {index} does not occur "
            "in the remaining training data"
        )
    log_proba = clf.predict_log_proba(X[index].reshape(1, -1))
    return -1 * log_proba[0, label_columns[0]]


def make_plot(output_filename):
    """Plot classifier sensitivity for male authors by five-year period.

    Steps:

    1. Load ``(df, dtm, vocab)`` from ``datasets_gender_genre.dataset()`` and
       drop the rows whose gender is "Unknown".
    2. Prune the vocabulary: drop columns which are non-zero in fewer than two
       rows, then columns whose chi-squared statistic against the gender label
       exceeds 7.
    3. Print 10-fold cross-validation accuracy of a logistic regression.
    4. Compute the leave-one-out logistic loss for every row and mark a row as
       correct when its loss is below ``-log(0.5)``.
    5. Plot, for rows with a male author only, the share of correct rows in
       each five-year period, and save the figure to `output_filename`.

    Side effects: prints diagnostics to stdout and writes the leave-one-out
    records to ``/tmp/plot.csv`` in addition to saving the figure.

    """
    df, dtm, vocab = datasets_gender_genre.dataset()
    # exclude Unknown
    unknown_mask = df["gender"] == "Unknown"
    df, dtm = df.loc[~unknown_mask], dtm[~unknown_mask]
    vocab = np.array(vocab)

    # exclude any words which are non-zero in fewer than 2 rows of the new dtm
    below_min_df_mask = (dtm > 0).sum(axis=0) < 2
    dtm, vocab = dtm[:, ~below_min_df_mask], vocab[~below_min_df_mask]

    y = (df["gender"] == "Male").astype(int)
    # only the test statistic is needed; the p-values are discarded
    chi2, _ = feature_selection.chi2(dtm, y)
    assert not np.isnan(chi2).any()
    # remove words which directly signal gender (chi-squared statistic > 7)
    print("vocab size before pruning:", len(vocab))
    vocab_remove_ix = chi2 > 7
    print("vocab words removed", vocab[vocab_remove_ix])
    dtm, vocab = dtm[:, ~vocab_remove_ix], vocab[~vocab_remove_ix]
    print("vocab size after pruning:", len(vocab))

    # Important notes:
    # - decade dummies do not help much
    # - forcing back adventures, freebooter, and naval does not help that much
    #   with: X_decades = pd.get_dummies(pd.cut(df['year'], range(1800, 1830 + 1, 10), right=False))
    # - tfidf does not help
    # - fasttext does not help

    X = dtm

    # cross-validation accuracy
    clf = linear_model.LogisticRegression(fit_intercept=True)
    num_folds = 10
    scores = model_selection.cross_val_score(clf, X, y, cv=num_folds)
    # The null model always predicts the majority class, whichever that is.
    male_share = y.mean()
    print("null model accuracy:", max(male_share, 1 - male_share))
    print("scores:", scores)
    print(
        f"cross-validation accuracy ({num_folds} folds): "
        f"{scores.mean():0.2f} (+/- {scores.std() * 2:0.2f}) (2sd)"
    )

    loo_year_records = []
    for index in tqdm.tqdm(range(len(X))):
        logistic_loss = leave_one_out_logistic_loss(X, y, index)
        row = df.iloc[index]
        loo_year_records.append(
            {
                "gender": row["gender"],
                "year": row["year"],
                "logistic_loss": logistic_loss,
            }
        )
    plot_df = pd.DataFrame.from_records(loo_year_records)
    # a loss below -log(0.5) means the observed label got probability > 0.5
    plot_df["correct"] = (plot_df["logistic_loss"] < -1 * np.log(0.5)).astype(int)
    assert plot_df["correct"].mean() > 0.5
    plot_df.to_csv("/tmp/plot.csv")
    # half-decade periods, smooths over year-to-year variation
    # NOTE(review): the bins are left-closed and the last edge is 1830, so a
    # row with year 1830 would not fall into any period -- confirm the data
    # never contains that year.
    plot_df["period"] = pd.cut(plot_df["year"], range(1800, 1830 + 1, 5), right=False)
    plot_df["period"] = plot_df["period"].apply(
        lambda interval: f"{interval.left}-{interval.right - 1}"
    )
    plot_df["man_author"] = plot_df["gender"] == "Male"

    # for sensitivity, only consider man_author
    plot_df = plot_df.loc[plot_df["man_author"]]

    # pandas .plot method does not work very well
    sensitivity = plot_df.groupby("period")["correct"].mean()
    print('Sensitivity:\n', sensitivity)

    # Draw on the explicit figure/axes rather than pyplot's implicit state.
    fig, ax = plt.subplots()
    positions = np.arange(len(sensitivity.index))
    ax.plot(positions, sensitivity.values)
    ax.set_ylim((0, 1.0))
    ax.set_xticks(positions)
    ax.set_xticklabels(sensitivity.index)
    ax.set_ylabel("Sensitivity")
    # title added to caption
    #ax.set_title("Association Between Author Gender and Novel Title Words")
    fig.tight_layout()
    fig.savefig(output_filename)
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    print(f"saved plot to `{output_filename}`")


if __name__ == "__main__":
    # Script entry point: the single positional argument is the figure path.
    cli_args = parser.parse_args()
    figure_path = cli_args.output_filename
    make_plot(figure_path)
